import math
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pandas as pd
from linearmodels.iv import IV2SLS
from linearmodels.iv.results import IVResults


# Default sentinel values that clean_numeric() treats as missing.
# NOTE: this one set is shared by every variable that relies on the default, so
# it also blanks *legitimate* values of wide-range variables (e.g. the numbers
# 7, 8, 9, 77 or 88). Pass ``missing_codes`` for such variables.
MISSING_CODES = {7, 8, 9, 77, 88, 99, 777, 888, 999, 7777, 8888, 9999}


def clean_numeric(series: pd.Series, missing_codes=None) -> pd.Series:
    """Coerce *series* to numeric and replace missing-value codes with NaN.

    Args:
        series: Raw column; entries that cannot be parsed as numbers become NaN.
        missing_codes: Optional iterable of numeric codes to blank out. When
            ``None`` (the default) the module-level ``MISSING_CODES`` set is
            used. Pass an empty iterable to disable code masking entirely.

    Returns:
        A numeric Series in which unparseable entries and every value found in
        the selected code set are NaN.
    """
    # ``is None`` (not truthiness) so that an explicit empty collection means
    # "mask nothing" instead of silently falling back to the defaults.
    codes = MISSING_CODES if missing_codes is None else missing_codes
    s = pd.to_numeric(series, errors="coerce")
    # list(...) lets callers hand in any iterable (set, tuple, generator, ...).
    return s.mask(s.isin(list(codes)))


def recode_yes_no(series: pd.Series) -> pd.Series:
    """Recode a 1 = yes / 2 = no item into a 1.0 / 0.0 indicator.

    The column is first passed through ``clean_numeric``; every value other
    than 1 or 2 (including anything that step blanked) comes back as NaN.
    """
    yes_no_to_indicator = {1: 1.0, 2: 0.0}
    cleaned = clean_numeric(series)
    return cleaned.map(yes_no_to_indicator)


def zscore(series: pd.Series) -> pd.Series:
    """Standardise *series* to mean 0 and (population) standard deviation 1.

    Values are coerced to numeric first (unparseable entries become NaN). When
    the standard deviation is undefined or not positive, the series multiplied
    by zero is returned instead, i.e. zeros where a value exists and NaN where
    it does not.
    """
    values = pd.to_numeric(series, errors="coerce")
    spread = values.std(ddof=0)
    if pd.notna(spread) and spread > 0:
        centred = values - values.mean()
        return centred / spread
    # Degenerate column: nothing to scale by.
    return values * 0


@dataclass(frozen=True)
class OutcomeSpec:
    """Immutable description of one outcome to be modelled."""

    name: str  # column name of the outcome in the analysis DataFrame
    label: str  # human-readable label shown in the result tables
    kind: str  # "binary" | "continuous" (not read by any code visible in this file)


# Variable-specific missing-value codes for columns whose *valid* range
# overlaps the short codes in the global MISSING_CODES set (valid ages 77/88/99,
# 7-9 years of schooling, income deciles 7-9, 7/8/9/77/... minutes of news).
# NOTE(review): codes follow the ESS codebook; confirm against the ESS rounds
# actually contained in the parquet extract.
_WIDE_RANGE_MISSING_CODES = {
    "agea": (999,),
    "eduyrs": (77, 88, 99),
    "hinctnta": (77, 88, 99),
    "nwspol": (7777, 8888, 9999),
}


def _clean_wide_range(df: pd.DataFrame, column: str) -> pd.Series:
    """Coerce *column* to numeric and blank only that variable's own missing codes."""
    values = pd.to_numeric(df[column], errors="coerce")
    return values.mask(values.isin(_WIDE_RANGE_MISSING_CODES[column]))


def build_dataset(root: Path) -> pd.DataFrame:
    """Load the survey extract, attach region-year instruments and derive model columns.

    Reads ``outputs/ess_r8_r11_min.parquet`` and
    ``data_external/broadband_region_year_eurostat_isoc_r_broad_h.csv`` below
    *root*, left-joins the regional indicators on (region, survey year) and
    adds the endogenous regressor, z-scored instruments, participation
    outcomes, mechanism/placebo variables, controls, the ``edu_high`` indicator
    and the country-by-year fixed-effect key ``ct``.

    Args:
        root: Project root directory containing ``outputs`` and ``data_external``.

    Returns:
        One row per respondent that has at least one non-missing instrument
        and a non-missing ``edu_high``.

    Raises:
        pandas.errors.MergeError: If the regional file holds more than one row
            for the same (region, year) key, which would duplicate respondents.
    """
    ess = pd.read_parquet(root / "outputs" / "ess_r8_r11_min.parquet")
    reg = pd.read_csv(root / "data_external" / "broadband_region_year_eurostat_isoc_r_broad_h.csv")

    ess = ess[ess["regunit"].isin([1, 2])].copy()
    ess["region"] = ess["region"].astype(str).str.upper().str.strip()
    ess = ess[ess["region"] != "99999"].copy()

    reg["region"] = reg["region"].astype(str).str.upper().str.strip()
    reg["year"] = pd.to_numeric(reg["year"], errors="coerce").astype("Int64")
    # reg.get(...) yields None for an absent column, which coerces to an all-NaN column.
    reg["broadband_pc_hh"] = pd.to_numeric(reg.get("broadband_pc_hh"), errors="coerce")
    reg["internet_access_pc_hh"] = pd.to_numeric(reg.get("internet_access_pc_hh"), errors="coerce")

    # validate="m:1": fail loudly instead of silently duplicating respondents
    # when the regional table has repeated (region, year) keys.
    df = ess.merge(
        reg,
        left_on=["region", "survey_year"],
        right_on=["region", "year"],
        how="left",
        validate="m:1",
    )

    # Endogenous
    df["netusoft"] = clean_numeric(df["netusoft"])

    # Instruments (region-year)
    df["z_broadband_pc_hh"] = zscore(df["broadband_pc_hh"])
    df["z_internet_access_pc_hh"] = zscore(df["internet_access_pc_hh"])

    # Participation outcomes
    for item in ("vote", "sgnptit", "pbldmn", "bctprd", "contplt", "badge"):
        df[f"y_{item}"] = recode_yes_no(df[item])
    df["y_pstplonl"] = recode_yes_no(df["pstplonl"]) if "pstplonl" in df.columns else np.nan
    part_cols = ["y_sgnptit", "y_pbldmn", "y_bctprd", "y_contplt", "y_badge", "y_pstplonl"]
    # A skipna mean over an all-NaN row is already NaN, so no extra reset is needed.
    df["y_part_index"] = df[part_cols].mean(axis=1, skipna=True)

    # Mechanism candidates
    df["nwspol"] = _clean_wide_range(df, "nwspol")
    df["log_nwspol"] = df["nwspol"].clip(lower=0).map(math.log1p, na_action="ignore")

    # Placebo outcomes (predetermined)
    df["gndr_num"] = clean_numeric(df["gndr"])  # 1/2
    df["brncntr_bin"] = recode_yes_no(df["brncntr"]) if "brncntr" in df.columns else np.nan

    # Controls
    df["agea"] = _clean_wide_range(df, "agea")
    df["gndr"] = df["gndr_num"]  # keep consistent name
    df["eduyrs"] = _clean_wide_range(df, "eduyrs")
    df["hinctnta"] = _clean_wide_range(df, "hinctnta")

    # Education indicator: high (> 12 years) vs low; NaN where eduyrs is missing
    df["edu_high"] = (df["eduyrs"] > 12).astype(float).where(df["eduyrs"].notna())

    # FE key (country×survey_year)
    df["ct"] = df["cntry"].astype(str) + "_" + df["survey_year"].astype("Int64").astype(str)

    # Keep rows where at least one instrument exists and edu_high exists
    has_instrument = df[["z_broadband_pc_hh", "z_internet_access_pc_hh"]].notna().any(axis=1)
    df = df[has_instrument & df["edu_high"].notna()].copy()

    return df


def fit_iv_interaction_ctfe(df: pd.DataFrame, y: str, controls: list[str]) -> "IVResults":
    """
    Fit the education-interaction 2SLS model with country-by-year fixed effects.

    Interaction-IV:
      y = b1*netusoft + b2*(netusoft*edu_high) + controls + ctFE + e
    Instruments:
      z_broadband, z_internet_access, and their interactions with edu_high

    Args:
        df: Analysis frame containing ``y``, ``netusoft``, ``edu_high``, both
            z-scored instruments, ``ct``, ``region`` and every control column.
        y: Name of the outcome column.
        controls: Names of additional exogenous regressors.

    Returns:
        The fitted result of ``IV2SLS(...).fit`` with the covariance clustered
        on ``region``.

    Raises:
        ValueError: If no row has complete data for all required columns.
    """
    cols = [y, "netusoft", "edu_high", "z_broadband_pc_hh", "z_internet_access_pc_hh", "ct", "region", *controls]
    # De-duplicate (order-preserving): a control that repeats one of the fixed
    # columns would otherwise yield duplicate column labels and break the
    # element-wise interaction products below.
    cols = list(dict.fromkeys(cols))
    d = df[cols].dropna().copy()
    if d.empty:
        raise ValueError(f"No complete cases available to estimate outcome {y!r}")

    d["netusoft_x_edu"] = d["netusoft"] * d["edu_high"]
    d["z_broadband_x_edu"] = d["z_broadband_pc_hh"] * d["edu_high"]
    d["z_inetacc_x_edu"] = d["z_internet_access_pc_hh"] * d["edu_high"]

    # FE via ct dummies. dtype=float keeps the design matrix purely numeric;
    # recent pandas versions return boolean dummy columns by default.
    d["ct"] = d["ct"].astype(str)
    ct_d = pd.get_dummies(d["ct"], prefix="ct", drop_first=True, dtype=float)

    exog = pd.concat([pd.DataFrame({"const": 1.0}, index=d.index), d[controls + ["edu_high"]], ct_d], axis=1)
    # edu_high may also be listed in controls; keep a single copy of each column.
    exog = exog.loc[:, ~exog.columns.duplicated()]

    endog = d[["netusoft", "netusoft_x_edu"]]
    instr = d[["z_broadband_pc_hh", "z_internet_access_pc_hh", "z_broadband_x_edu", "z_inetacc_x_edu"]]

    mod = IV2SLS(d[y], exog=exog, endog=endog, instruments=instr)
    return mod.fit(cov_type="clustered", clusters=d["region"])


def summarize_result(res, y: str, label: str) -> dict:
    """Flatten one fitted interaction-IV result into a single table row.

    Args:
        res: Fitted result exposing ``params``, ``std_errors``, ``nobs`` and
            ``first_stage.diagnostics``.
        y: Outcome column name, stored under ``"outcome"``.
        label: Display label, stored under ``"label"``.

    Returns:
        A dict with the ``netusoft`` coefficient (low-education effect), the
        ``netusoft_x_edu`` coefficient (high-minus-low difference), their
        standard errors, the implied high-education effect (sum of both
        coefficients) and the first-stage ``f.stat`` of each endogenous term.
        Missing entries are reported as NaN.
    """
    base_term = "netusoft"
    interaction_term = "netusoft_x_edu"

    def _value(estimates, term: str) -> float:
        return float(estimates.get(term, np.nan))

    coef_base = _value(res.params, base_term)
    coef_interaction = _value(res.params, interaction_term)
    se_base = _value(res.std_errors, base_term)
    se_interaction = _value(res.std_errors, interaction_term)

    first_stage = res.first_stage.diagnostics
    # Fall back to positional rows when the diagnostics are not labelled by term.
    if base_term in first_stage.index:
        row_base = first_stage.loc[base_term]
    else:
        row_base = first_stage.iloc[0]
    if interaction_term in first_stage.index:
        row_interaction = first_stage.loc[interaction_term]
    else:
        row_interaction = first_stage.iloc[-1]

    summary = {"outcome": y, "label": label, "nobs": int(res.nobs)}
    summary["coef_lowEdu"] = coef_base
    summary["se_lowEdu"] = se_base
    summary["coef_deltaHighMinusLow"] = coef_interaction
    summary["se_deltaHighMinusLow"] = se_interaction
    summary["coef_highEdu"] = coef_base + coef_interaction
    summary["fsF_netusoft"] = float(row_base.get("f.stat", np.nan))
    summary["fsF_netusoft_x_edu"] = float(row_interaction.get("f.stat", np.nan))
    return summary


def to_markdown_table(df: pd.DataFrame) -> str:
    """Render the summary rows as a Markdown table with fixed decimals.

    Coefficients and standard errors are shown with four decimals, first-stage
    F statistics with two; missing values become empty cells. The input frame
    is not modified.

    Args:
        df: Result rows containing at least the columns selected below.

    Returns:
        The Markdown table as a string.
    """
    cols = [
        "label",
        "nobs",
        "coef_lowEdu",
        "se_lowEdu",
        "coef_highEdu",
        "coef_deltaHighMinusLow",
        "se_deltaHighMinusLow",
        "fsF_netusoft",
        "fsF_netusoft_x_edu",
    ]
    d = df[cols].copy()
    for c in ["coef_lowEdu", "se_lowEdu", "coef_highEdu", "coef_deltaHighMinusLow", "se_deltaHighMinusLow"]:
        d[c] = d[c].map(lambda x: f"{x:.4f}" if pd.notna(x) else "")
    for c in ["fsF_netusoft", "fsF_netusoft_x_edu"]:
        d[c] = d[c].map(lambda x: f"{x:.2f}" if pd.notna(x) else "")
    # disable_numparse is forwarded to tabulate: without it, number-like
    # strings are parsed back into floats and re-rendered in general format
    # (e.g. "1.5000" -> "1.5"), discarding the fixed-decimal formatting above.
    return d.to_markdown(index=False, disable_numparse=True)


def main() -> None:
    """Run the main, mechanism and placebo IV suites and write the result files.

    Writes below ``<root>/outputs``: a CSV with every fitted model, one
    Markdown table for the main outcomes, one for the placebo/mechanism
    outcomes, and a plain-text checks file that also lists every model that
    failed to estimate. A failing model never aborts the run.
    """
    root = Path(__file__).resolve().parents[1]
    out_dir = root / "outputs"
    out_dir.mkdir(parents=True, exist_ok=True)

    df = build_dataset(root)

    years = sorted(int(x) for x in df["survey_year"].dropna().unique())
    countries = sorted(df["cntry"].dropna().astype(str).unique().tolist())

    # Main outcomes: participation structure (supports inequality story)
    main_outcomes = [
        OutcomeSpec("y_part_index", "Participation index (mean of items)", "continuous"),
        OutcomeSpec("y_vote", "Voted", "binary"),
        OutcomeSpec("y_sgnptit", "Signed petition", "binary"),
        OutcomeSpec("y_pbldmn", "Demonstration", "binary"),
        OutcomeSpec("y_bctprd", "Boycott", "binary"),
        OutcomeSpec("y_contplt", "Contacted politician", "binary"),
        OutcomeSpec("y_badge", "Worn badge/sticker", "binary"),
        OutcomeSpec("y_pstplonl", "Posted politics online", "binary"),
    ]

    mechanism_outcomes = [
        OutcomeSpec("log_nwspol", "Mechanism: log(1+nwspol minutes)", "continuous"),
    ]

    placebo_outcomes = [
        OutcomeSpec("gndr", "Placebo: gender (numeric code)", "continuous"),
        OutcomeSpec("brncntr_bin", "Placebo: born in country (1/0)", "binary"),
    ]

    # Controls (keep minimal and consistent)
    base_controls = ["agea", "gndr", "eduyrs", "hinctnta"]

    # Column layout of one result row; used to build a well-formed empty table
    # when no model could be estimated.
    result_columns = [
        "outcome",
        "label",
        "nobs",
        "coef_lowEdu",
        "se_lowEdu",
        "coef_deltaHighMinusLow",
        "se_deltaHighMinusLow",
        "coef_highEdu",
        "fsF_netusoft",
        "fsF_netusoft_x_edu",
        "suite",
    ]

    rows: list[dict] = []
    errors: list[str] = []

    def run_suite(specs: list[OutcomeSpec], suite_name: str) -> None:
        # rows/errors are only mutated in place, so no ``nonlocal`` is required.
        for s in specs:
            if s.name not in df.columns:
                continue
            # For placebo outcomes that overlap controls, drop that control.
            controls = [c for c in base_controls if c != s.name]
            try:
                res = fit_iv_interaction_ctfe(df, s.name, controls=controls)
                rows.append({**summarize_result(res, s.name, f"{suite_name}: {s.label}"), "suite": suite_name})
            except Exception as e:  # one failing outcome must not abort the remaining models
                errors.append(f"[{suite_name}] {s.name}: {type(e).__name__}: {e}")

    run_suite(main_outcomes, "main")
    run_suite(mechanism_outcomes, "mechanism")
    run_suite(placebo_outcomes, "placebo")

    out_csv = out_dir / "iv_interaction_edu_main_table.csv"
    out_md = out_dir / "iv_interaction_edu_main_table.md"
    out_md_checks = out_dir / "iv_interaction_edu_placebo_mechanism.md"
    out_txt = out_dir / "iv_interaction_edu_checks.txt"

    if rows:
        res_df = pd.DataFrame(rows).sort_values(["suite", "label"]).reset_index(drop=True)
    else:
        # Every model failed: an empty DataFrame has no "suite"/"label" columns,
        # so sorting/filtering would raise before the error log is written.
        res_df = pd.DataFrame(columns=result_columns)
    res_df.to_csv(out_csv, index=False, encoding="utf-8")
    out_md.write_text(to_markdown_table(res_df[res_df["suite"] == "main"]), encoding="utf-8")
    out_md_checks.write_text(to_markdown_table(res_df[res_df["suite"] != "main"]), encoding="utf-8")

    checks = []
    checks.append("IV interaction suite (ctFE; clustered by region)")
    checks.append(f"Years in sample: {years}")
    checks.append(f"Countries in sample: {len(countries)}")
    checks.append(f"Rows (edu_high non-missing): {len(df):,}")
    checks.append("")
    if errors:
        checks.append("Errors:")
        checks.extend(errors)
        checks.append("")
    checks.append("Main table CSV:")
    checks.append(str(out_csv))
    checks.append("Main table markdown (main outcomes only):")
    checks.append(str(out_md))
    checks.append("Placebo/mechanism markdown:")
    checks.append(str(out_md_checks))
    out_txt.write_text("\n".join(checks) + "\n", encoding="utf-8")

    print(f"Wrote: {out_csv}")
    print(f"Wrote: {out_md}")
    print(f"Wrote: {out_md_checks}")
    print(f"Wrote: {out_txt}")
    if errors:
        # Failures are otherwise only visible inside the checks file.
        print(f"{len(errors)} model(s) failed to estimate; see {out_txt}")


# Run the full pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
